#coding=utf-8

from nltk.tree import Tree
import numpy

def _parse_tree(text):
    """Parse one bracketed tree string into an nltk ``Tree``.

    NLTK 3 removed the string-parsing ``Tree(text)`` constructor in favour of
    ``Tree.fromstring``; fall back to the legacy constructor on old NLTK.
    """
    if hasattr(Tree, 'fromstring'):
        return Tree.fromstring(text)
    return Tree(text)


def _tree_label(tree):
    """Return the node label of *tree* (``label()`` on NLTK 3, ``.node`` before)."""
    label = getattr(tree, 'label', None)
    if callable(label):
        return label()
    return tree.node


def _iter_trees(path):
    """Yield one parsed tree per non-blank line of the file at *path*."""
    with open(path) as f:
        for line in f:
            # Blank lines (e.g. a trailing newline at EOF) are not trees.
            if line.strip():
                yield _parse_tree(line)


def _read_roots(path):
    """Return a ``(sentence, label)`` pair for the root of every tree in *path*."""
    return [(' '.join(tree.leaves()), _tree_label(tree)) for tree in _iter_trees(path)]


train = []
train_set = set()
total = []

# Every distinct phrase (subtree) of the training trees becomes a training
# example; when a phrase occurs more than once, the first label seen wins.
# Only the whole-line (root) sentences go into ``total``.
for tree in _iter_trees('train.txt'):
    total.append((' '.join(tree.leaves()), _tree_label(tree)))
    for subTree in tree.subtrees():
        content = ' '.join(subTree.leaves())
        if content not in train_set:
            train.append((content, _tree_label(subTree)))
            train_set.add(content)

# dev/test only use the root (whole-line) label of each tree.
dev = _read_roots('dev.txt')
total.extend(dev)

test = _read_roots('test.txt')
total.extend(test)

# Shuffle each split in place; ``total`` keeps its file order
# (train roots, then dev, then test) because it was built beforehand.
numpy.random.shuffle(train)
numpy.random.shuffle(test)
numpy.random.shuffle(dev)

# Five-class files keep each example's original label.
# Two-class files map labels '0'/'1' -> '0' and '3'/'4' -> '1'; examples whose
# label is not in this table (e.g. '2') are left out of the two-class files.
_BINARY_LABEL = {'0': '0', '1': '0', '3': '1', '4': '1'}


def _write_split(examples, name):
    """Write the five-class and two-class files for one data split.

    Produces ``<name>5.data``/``<name>5.label`` with every ``(sentence, label)``
    pair in *examples*, and ``<name>2.data``/``<name>2.label`` with only the
    pairs whose label appears in ``_BINARY_LABEL``, relabelled to '0'/'1'.
    All four files are closed even if a write fails.
    """
    with open(name + '5.data', 'w') as f5_data, \
            open(name + '5.label', 'w') as f5_label, \
            open(name + '2.data', 'w') as f2_data, \
            open(name + '2.label', 'w') as f2_label:
        for sentence, label in examples:
            f5_data.write(sentence + '\n')
            f5_label.write(label + '\n')
            binary = _BINARY_LABEL.get(label)
            if binary is not None:
                f2_data.write(sentence + '\n')
                f2_label.write(binary + '\n')


# One sentence per line for every split (labels are not written here).
with open('total.data', 'w') as f_total:
    for sentence, _ in total:
        f_total.write(sentence + '\n')

_write_split(train, 'train')
_write_split(test, 'test')
_write_split(dev, 'dev')
